from sklearn.cluster import KMeans
import numpy as np
from scipy.spatial.distance import cdist, euclidean
import time
from numba import jit

class K_median_alg:
    """Iterative k-median clustering with geometric-median centre updates.

    Each iteration of :meth:`fit` labels every point with its nearest centre
    and then moves each centre to the geometric median of the points labelled
    with it. The numeric work is delegated to the module-level helpers
    ``center_seeding_fast``, ``reassign_points_faster``,
    ``calculate_geometric_median_faster`` and ``total_cost_kmedian``.
    """

    def __init__(self, k, max_iter, n_init):
        """Store the clustering configuration.

        Args:
            k: Number of clusters (and centres).
            max_iter: Upper bound on the number of iterations run by ``fit``.
            n_init: Accepted for interface compatibility; it is not used by
                this class (``get_original_cost`` always does 5 restarts).
        """
        self.k = k
        self.centers = None  # centres returned by the most recent fit()
        self.max_iter = max_iter
        self.center_full = None  # best centres found by get_original_cost()
        self.centers_subset = None  # centres fitted by get_subset_and_original()
        self.history = []  # total cost recorded after each fit() iteration
        # Not referenced by any method of this class; kept so the public
        # attribute remains available to existing callers.
        self.kmeans = KMeans(n_clusters=k)

    def get_centers_full(self):
        """Return the centres stored by ``get_original_cost``."""
        assert self.center_full is not None
        return self.center_full

    def get_centers_subset(self):
        """Return the centres stored by ``get_subset_and_original``."""
        assert self.centers_subset is not None
        return self.centers_subset

    def get_original_cost(self, data):
        """Fit ``data`` 5 times and keep the cheapest set of centres.

        The winning centres are stored in ``self.center_full``.

        Returns:
            The k-median cost of ``data`` under the winning centres.
        """
        best_score = None
        best_centers = None
        for _ in range(5):
            candidate = self.fit(data)
            score = self.get_cost_for_centers(data, candidate)
            # Strict comparison: on ties the earliest restart wins.
            if best_score is None or score < best_score:
                best_score = score
                best_centers = candidate
        self.center_full = best_centers
        # best_score is exactly the cost of best_centers on data, so there is
        # no need to evaluate the cost a second time.
        return best_score

    def get_subset_and_original(self, data_full, data_subset):
        """Fit on ``data_subset`` and score the centres on both data sets.

        The fitted centres are stored in ``self.centers_subset``.

        Returns:
            Tuple ``(cost on data_subset, cost on data_full)``.
        """
        centers = self.fit(data_subset)
        self.centers_subset = centers
        cost_subset = self.get_cost_for_centers(data_subset, centers)
        cost_original = self.get_cost_for_centers(data_full, centers)
        return cost_subset, cost_original

    def get_cost_for_centers(self, data, centers):
        """Return the k-median cost of ``data`` for the given ``centers``."""
        return total_cost_kmedian(data, centers)

    def check_convergence(self):
        """Check if cost is not decreasing for 5 iterations.

        Returns ``False`` until at least five costs have been recorded.
        After that, returns ``True`` only when the latest cost is not more
        than 1% below any of the last four recorded costs (the latest cost
        itself is one of those four).
        """
        if len(self.history) < 5:
            return False
        latest = self.history[-1]
        still_improving = any(
            latest < 0.99 * past for past in self.history[-4:]
        )
        return not still_improving

    def center_seeding(self, data, k):
        """Pick ``k`` initial centres from ``data``."""
        return center_seeding_fast(data, k)

    def all_assigned(self, clusters):
        """Check if all values of k are in clusters.

        Prints a message naming the first missing cluster index.
        """
        for i in range(self.k):
            if i not in clusters:
                print("cluster ", i, " is empty")
                return False
        return True

    def reassign_empty_clusters(self, data, clusters, centers):
        """Move the centre of every empty cluster onto a far-away data point.

        Works on a copy of ``centers`` in which NaN entries are replaced by
        -100. For each cluster index missing from ``clusters`` the centre is
        set to the data point whose distance to its nearest current centre is
        largest (the first such point on ties).

        Args:
            data: 2-D array of points, one per row.
            clusters: Cluster label of every point.
            centers: Current centres, one per row.

        Returns:
            The updated copy of ``centers``.

        Raises:
            ValueError: If no data point lies at a positive distance from
                every centre, so there is nothing to re-seed with.
        """
        centers = np.nan_to_num(centers, nan=-100)
        for i in range(self.k):
            if i in clusters:
                continue
            farthest_point = None
            farthest_dist = 0
            for j in range(data.shape[0]):
                dist = np.sqrt(np.min(np.sum((data[j] - centers) ** 2, axis=1)))
                if dist > farthest_dist:
                    farthest_dist = dist
                    farthest_point = data[j]
            if farthest_point is None:
                # Every point already coincides with a centre (fewer distinct
                # points than clusters). Writing None here would poison the
                # centres and fit() could never fill this cluster.
                raise ValueError(
                    "cannot re-seed empty cluster %d: no data point lies "
                    "away from the existing centers" % i
                )
            centers[i] = farthest_point
        return centers

    def fit(self, data):
        """Run the k-median iteration on ``data`` and return the centres.

        Stops after ``max_iter`` iterations or as soon as
        ``check_convergence`` reports convergence. The per-iteration cost is
        recorded in ``self.history`` and the result is stored in
        ``self.centers``.

        Returns:
            Array of centres, one per row.
        """
        centers = self.center_seeding(data, self.k)
        print("fitting kmedian")
        self.history = []
        for iteration in range(self.max_iter):
            assert not np.isnan(centers).any()
            # Must be `and`: with `&` the test parsed as the chained
            # comparison `i % 10 == (0 & i) > 0`, which is always False, so
            # the progress message was never printed.
            if iteration > 0 and iteration % 10 == 0:
                print("reassignment iteration: ", iteration)
            if self.check_convergence():
                break
            labels = self.reassign_points(data, centers)
            # Re-seed until every cluster owns at least one point.
            while not self.all_assigned(labels):
                print("reassigning empty clusters")
                centers = self.reassign_empty_clusters(data, labels, centers)
                labels = self.reassign_points(data, centers)
            for j in range(self.k):
                centers[j] = self.calculate_geometric_median(data[labels == j])
                assert not np.isnan(centers[j]).any()
            self.history.append(self.get_cost_for_centers(data, centers))
        self.centers = centers
        return self.centers

    def reassign_points(self, data, centers):
        """Return the index of the nearest centre for every point."""
        return reassign_points_faster(data, centers)

    def calculate_geometric_median(self, X):
        """Return the geometric median of the rows of ``X``."""
        return calculate_geometric_median_faster(X)
    
def calculate_geometric_median_faster(X, eps=1e-5):
    """Approximate the geometric median of the rows of ``X``.

    Starts from the coordinate-wise mean and repeatedly replaces the estimate
    by the inverse-distance weighted average of the rows that do not coincide
    with it, with a correction step when some rows do coincide. (This appears
    to follow the Vardi-Zhang variant of Weiszfeld's algorithm - confirm
    against the original reference if it matters.)

    Args:
        X: 2-D array with one point per row; must contain at least one row.
        eps: The iteration stops once two successive estimates are closer
            than this Euclidean distance.

    Returns:
        1-D array holding the estimated median.

    Raises:
        AssertionError: If ``X`` has no rows or an estimate contains NaN.
    """
    # Validate before touching the data: np.mean of an empty array emits a
    # RuntimeWarning and yields NaN, which used to obscure the real problem.
    assert X.shape[0] > 0, "X must contain at least one point"
    estimate = np.mean(X, 0)
    n_points = len(X)
    while True:
        assert not np.isnan(estimate).any()
        dists = cdist(X, [estimate])
        away = (dists != 0)[:, 0]  # rows not sitting exactly on the estimate
        n_coincident = n_points - np.count_nonzero(away)
        if n_coincident == n_points:
            # Every row equals the estimate, so it is the median. Returning
            # here also avoids dividing by a zero weight total below.
            return estimate

        inv_dists = 1 / dists[away]
        inv_total = np.sum(inv_dists)
        weights = inv_dists / inv_total
        weighted_avg = np.sum(weights * X[away], 0)

        if n_coincident == 0:
            updated = weighted_avg
        else:
            # Some rows coincide with the estimate: blend the weighted
            # average with the current estimate instead of jumping to it.
            pull = (weighted_avg - estimate) * inv_total
            pull_norm = np.linalg.norm(pull)
            ratio = 0 if pull_norm == 0 else n_coincident / pull_norm
            updated = max(0, 1 - ratio) * weighted_avg + min(1, ratio) * estimate

        if euclidean(estimate, updated) < eps:
            assert not np.isnan(updated).any()
            return updated

        estimate = updated
@jit(nopython=True)
def fast_distance(p, q):
    """Return the Euclidean norm of every row of ``p - q``.

    The reduction runs over axis 1, so the difference ``p - q`` has to be at
    least two-dimensional.
    """
    delta = p - q
    squared_norms = np.sum(delta * delta, axis=1)
    return np.sqrt(squared_norms)
    
@jit(nopython=True)
def reassign_points_faster(data, centers):
    """Label each row of ``data`` with the index of its nearest centre.

    Nearness is measured by squared Euclidean distance. The labels are
    returned in a float array (the default dtype of ``np.zeros``) with one
    entry per row of ``data``.
    """
    n_points = data.shape[0]
    labels = np.zeros(n_points)
    for row in range(n_points):
        offset = data[row] - centers
        squared = np.sum(offset * offset, axis=1)
        labels[row] = np.argmin(squared)
    return labels

@jit(nopython=True)
def total_cost_kmedian(data, centers):
    """Return the k-median cost of ``data`` with respect to ``centers``.

    The cost is the sum, over all rows of ``data``, of the Euclidean distance
    from the row to its closest centre.
    """
    total = 0
    n_points = data.shape[0]
    for row in range(n_points):
        offset = data[row] - centers
        nearest_sq = np.min(np.sum(offset * offset, axis=1))
        total += np.sqrt(nearest_sq)
    return total

def center_seeding_fast(data, k):
    """Draw ``k`` initial centres from the rows of ``data``.

    The first centre is a uniformly random row. Every further centre is a row
    sampled with probability proportional to its distance from the nearest
    centre chosen so far. Randomness comes from NumPy's global ``np.random``
    state.

    Returns:
        Array of shape ``(k, data.shape[1])`` holding the chosen rows.
    """
    n_points = data.shape[0]
    seeds = np.full((k, data.shape[1]), -1.0)
    seeds[0] = data[np.random.randint(n_points)]
    chosen = 1
    while chosen < k:
        # Only the already-chosen prefix takes part in the distance query.
        nearest = dist_to_centers(data, seeds[:chosen])
        weights = nearest / nearest.sum()
        pick = np.random.choice(n_points, p=weights)
        seeds[chosen] = data[pick]
        chosen += 1
    return seeds

@jit(nopython=True)
def dist_to_centers(data, centers):
    """Return, for each row of ``data``, the distance to its closest centre.

    The result is a 1-D float array with one Euclidean distance per row.
    """
    n_points = data.shape[0]
    nearest = np.zeros(n_points)
    for row in range(n_points):
        offset = data[row] - centers
        squared = np.sum(offset * offset, axis=1)
        nearest[row] = np.sqrt(np.min(squared))
    return nearest